import numpy as np
import pickle, struct, socket, math


def get_even_odd_from_one_hot_label(label):
    """Map a one-hot label to +1 / -1 according to the parity of its hot slot.

    Scans ``label`` for the first entry equal to 1. Returns 1 when that
    position is even and -1 when it is odd. Returns None when no entry
    equals 1.
    """
    for position, value in enumerate(label):
        if value == 1:
            return -1 if position % 2 else 1
    return None


def get_index_from_one_hot_label(label):
    """Return ``[position]`` of the first entry of ``label`` equal to 1.

    Returns None when no entry equals 1.
    """
    hot_positions = (pos for pos, value in enumerate(label) if value == 1)
    first_hot = next(hot_positions, None)
    return None if first_hot is None else [first_hot]


def get_one_hot_from_label_index(label, number_of_labels=10):
    """Return a float vector of length ``number_of_labels`` with a 1 at ``label``.

    ``label`` is used directly as a NumPy index, so anything NumPy accepts
    for item assignment (including negative positions) is honoured.
    """
    encoded = np.zeros(number_of_labels)
    encoded[label] = 1
    return encoded

# NOTE(review): this triple-quoted string is an inert expression holding an
# earlier send_msg/recv_msg pair; it is never executed. The live definitions
# follow below. The old recv_msg read the 4-byte length prefix with a single
# sock.recv(4), which may return fewer than 4 bytes. Consider deleting this
# dead code.
'''
def send_msg(sock, msg):
    msg_pickle = pickle.dumps(msg)
    sock.sendall(struct.pack(">I", len(msg_pickle)))
    sock.sendall(msg_pickle)
    print(msg[0], 'sent to', sock.getpeername())


def recv_msg(sock, expect_msg_type=None):
    msg_len = struct.unpack(">I", sock.recv(4))[0]
    msg = sock.recv(msg_len, socket.MSG_WAITALL)
    msg = pickle.loads(msg)
    print(msg[0], 'received from', sock.getpeername())

    if (expect_msg_type is not None) and (msg[0] != expect_msg_type):
        raise Exception("Expected " + expect_msg_type + " but received " + msg[0])
    return msg
'''

def send_msg(sock, msg):
    """Pickle ``msg`` and write it to ``sock`` as one length-prefixed frame.

    The frame is a 4-byte big-endian unsigned length (``struct`` format
    ``">I"``) followed by the pickled bytes. Afterwards ``msg[0]`` is printed
    together with the peer address.
    """
    payload = pickle.dumps(msg)
    header = struct.pack(">I", len(payload))
    for part in (header, payload):
        sock.sendall(part)
    print(msg[0], 'sent to', sock.getpeername())

def _recv_exact(sock, n_bytes):
    """Read up to ``n_bytes`` from ``sock``, looping until all arrive or the peer closes.

    A single ``recv`` may legally return fewer bytes than requested. Once a
    socket timeout is set, Python runs the descriptor in non-blocking mode,
    so ``MSG_WAITALL`` does not guarantee a full read either. The result is
    shorter than ``n_bytes`` only when the peer shut the connection down
    first.
    """
    buffer = bytearray(n_bytes)
    view = memoryview(buffer)
    received = 0
    while received < n_bytes:
        got = sock.recv_into(view[received:], n_bytes - received)
        if got == 0:  # orderly shutdown by the peer
            break
        received += got
    return buffer if received == n_bytes else buffer[:received]


def recv_msg(sock, expect_msg_type=None):
    """Receive one length-prefixed, pickled message from ``sock``.

    Wire format (the framing written by ``send_msg``): a 4-byte big-endian
    unsigned length (``">I"``) followed by that many bytes of pickle data.
    After decoding, ``msg[0]`` is printed together with the peer address.

    Args:
        sock: connected stream socket (blocking or with a timeout).
        expect_msg_type: if not None, ``msg[0]`` must equal this value.

    Returns:
        The unpickled message object.

    Raises:
        ConnectionError: the peer closed before the length prefix or the
            payload was fully received.
        Exception: ``msg[0]`` differs from ``expect_msg_type``.
        Any socket or unpickling error is printed and re-raised unchanged.
    """
    try:
        msg_len_data = _recv_exact(sock, 4)
        if len(msg_len_data) < 4:
            raise ConnectionError("Failed to receive the complete message length")

        msg_len = struct.unpack(">I", msg_len_data)[0]

        msg_data = _recv_exact(sock, msg_len)
        if len(msg_data) < msg_len:
            raise ConnectionError("Failed to receive the complete message data")

        # SECURITY: pickle.loads can execute arbitrary code embedded in the
        # payload; only use this with trusted peers.
        msg = pickle.loads(msg_data)
        print(msg[0], 'received from', sock.getpeername())

        if (expect_msg_type is not None) and (msg[0] != expect_msg_type):
            raise Exception(f"Expected {expect_msg_type} but received {msg[0]}")

        return msg
    except Exception as e:
        # Log and re-raise so callers still see the original exception.
        print(f"Error receiving message: {e}")
        raise
    


def moving_average(param_mvavr, param_new, movingAverageHoldingParam):
    """Blend ``param_new`` into the running average ``param_mvavr``.

    - If the running value is None or NaN, ``param_new`` replaces it.
    - If ``param_new`` is NaN, the running value is returned unchanged.
    - Otherwise the result is
      ``movingAverageHoldingParam * param_mvavr
      + (1 - movingAverageHoldingParam) * param_new``.
    """
    if param_mvavr is None or np.isnan(param_mvavr):
        return param_new
    if np.isnan(param_new):
        return param_mvavr
    keep = movingAverageHoldingParam
    return keep * param_mvavr + (1 - keep) * param_new


def _least_loaded_node(node_lists, n_nodes, num_labels, target):
    """Return the node ``n`` with ``n % num_labels == target`` holding the fewest indices.

    The first such node wins on ties; 0 is returned if no node matches.
    """
    best_node = 0
    best_size = math.inf
    for n in range(n_nodes):
        if n % num_labels == target and len(node_lists[n]) < best_size:
            best_size = len(node_lists[n])
            best_node = n
    return best_node


def _group_indices_by_label(label_list, min_label, num_labels):
    """Map each label value to the ascending list of sample positions carrying it.

    Keys are pre-seeded with ``min_label + k`` for ``k`` in
    ``range(int(num_labels))`` (possibly with empty lists). Key order, and
    hence tie order when sorting by frequency, therefore follows label value.
    Labels outside that grid are appended as extra keys.
    """
    groups = {min_label + k: [] for k in range(int(num_labels))}
    for position, label in enumerate(label_list):
        groups.setdefault(label, []).append(position)
    return groups


def _assign_fractional_label_shares(node_lists, n_nodes, sorted_labels,
                                    label_to_indices, distribution):
    """Hand a fraction of each label's samples to one randomly ordered node.

    The k-th label of ``sorted_labels`` uses slot ``k % len(distribution)``.
    It contributes its first ``int(count * distribution[slot])`` sample
    indices to node ``node_order[slot % n_nodes]``, where ``node_order`` is a
    fresh ``np.random.shuffle`` of ``range(n_nodes)``.
    """
    node_order = np.arange(n_nodes)
    np.random.shuffle(node_order)
    for k, label in enumerate(sorted_labels):
        slot = k % len(distribution)
        samples = label_to_indices[label]
        count = int(len(samples) * distribution[slot])
        # Wrap so fewer nodes than distribution slots does not IndexError.
        node_lists[node_order[slot % n_nodes]].extend(samples[:count])


def get_indices_each_node_case(n_nodes, maxCase, label_list):
    """Partition sample indices among ``n_nodes`` nodes under several split cases.

    Returns a nested list ``result[case][node]`` of sample indices (positions
    in ``label_list``) with ``maxCase`` case slots. Slots without a rule
    below stay lists of empty lists.

    Cases (0-based slot):
      0 - round robin: sample ``i`` goes to node ``i % n_nodes``.
      1 - by label: node ``(label - min_label) % n_nodes``. When there are
          more nodes than labels, the least-loaded node congruent to that
          target modulo the number of labels is used instead.
      2 - every node receives every sample index.
      3 - labels below the midpoint of the label range are spread by sample
          position over the first ``ceil(min(n_nodes, num_labels) / 2)``
          nodes. The other labels go by label to the following nodes, with
          the same least-loaded rebalancing as case 1.
      4 - filled only when ``maxCase > 4``. Labels are visited from most to
          least frequent, with share schedule [0.1, 0.1, 0.2, 0.2, 0.4].
      5 - filled only when ``maxCase > 5``. Same as case 4 but with share
          schedule [0.1, 0.1, 0.1, 0.1, 0.6].
      NOTE(review): in cases 4/5 each label sends only ``int(count * share)``
      of its samples to a single node, so most samples are assigned nowhere.
      This keeps the original allocation rule; confirm it is intended.

    Args:
        n_nodes: number of nodes.
        maxCase: number of case slots in the result; must be at least 4
            because slots 0-3 are always filled.
        label_list: per-sample labels; presumably integer-valued -- TODO
            confirm against callers.

    Cases 4 and 5 each consume one ``np.random.shuffle`` of global NumPy RNG
    state.
    """
    indices_each_node_case = [[[] for _ in range(n_nodes)] for _ in range(maxCase)]

    min_label = min(label_list)
    max_label = max(label_list)
    num_labels = max_label - min_label + 1

    # Loop invariants for case 4.
    lower_node_count = int(np.ceil(min(n_nodes, num_labels) / 2))
    label_midpoint = (min_label + max_label) / 2

    for i, label in enumerate(label_list):
        # case 1: round robin by sample position
        indices_each_node_case[0][i % n_nodes].append(i)

        # case 2: node chosen by label
        tmp_target_node = int((label - min_label) % n_nodes)
        if n_nodes > num_labels:
            tmp_target_node = _least_loaded_node(
                indices_each_node_case[1], n_nodes, num_labels, tmp_target_node)
        indices_each_node_case[1][tmp_target_node].append(i)

        # case 3: every node gets every sample
        for n in range(n_nodes):
            indices_each_node_case[2][n].append(i)

        # case 4: lower labels spread by position, upper labels by label.
        # With a single node neither branch fires for upper labels, and the
        # case-2 target (always 0 when n_nodes == 1) is reused.
        if label < label_midpoint:
            tmp_target_node = i % lower_node_count
        elif n_nodes > 1:
            tmp_target_node = int(
                ((label - min_label) % (min(n_nodes, num_labels) - lower_node_count))
                + lower_node_count)
        if n_nodes > num_labels:
            tmp_target_node = _least_loaded_node(
                indices_each_node_case[3], n_nodes, num_labels, tmp_target_node)
        indices_each_node_case[3][tmp_target_node].append(i)

    # Slots 4 and 5 exist only when the caller asked for them; callers that
    # pass maxCase == 4 (or 5) must not crash here.
    if maxCase > 4:
        label_to_indices = _group_indices_by_label(label_list, min_label, num_labels)
        # Stable sort: equal-frequency labels keep ascending label order.
        sorted_labels = sorted(
            label_to_indices, key=lambda lab: len(label_to_indices[lab]), reverse=True)

        # case 5
        _assign_fractional_label_shares(
            indices_each_node_case[4], n_nodes, sorted_labels, label_to_indices,
            [0.1, 0.1, 0.2, 0.2, 0.4])

        # case 6
        if maxCase > 5:
            _assign_fractional_label_shares(
                indices_each_node_case[5], n_nodes, sorted_labels, label_to_indices,
                [0.1, 0.1, 0.1, 0.1, 0.6])

    return indices_each_node_case





## The code above adds two label-distribution cases, case 5 and case 6, to get_indices_each_node_case. They use the share schedules 10%, 10%, 20%, 20%, 40% (case 5) and 10%, 10%, 10%, 10%, 60% (case 6) respectively.